import socket
import ssl
from urllib.parse import urlparse

import pytest
import redis
from redis.exceptions import ConnectionError, RedisError

from .conftest import skip_if_cryptography, skip_if_nocryptography
from .ssl_utils import CertificateType, get_tls_certificates


@pytest.mark.ssl
class TestSSL:
    """Tests for SSL connections

    This relies on the --redis-ssl-url purely for rebuilding the client
    and connecting to the appropriate port. Several OCSP tests additionally
    open connections to public internet hosts on port 443.
    """

    @pytest.fixture(autouse=True)
    def _set_ssl_certs(self, request):
        """Load the client and server TLS file bundles before every test.

        The certificate sub-directory is read from the session's
        ``REDIS_INFO["tls_cert_subdir"]``.
        """
        tls_cert_subdir = request.session.config.REDIS_INFO["tls_cert_subdir"]
        self.client_certs = get_tls_certificates(tls_cert_subdir)
        self.server_certs = get_tls_certificates(
            tls_cert_subdir, cert_type=CertificateType.server
        )

    @staticmethod
    def _ssl_host_port(request):
        """Return ``(hostname, port)`` of the server given by ``--redis-ssl-url``.

        Uses the parsed ``hostname``/``port`` attributes instead of splitting
        the netloc on ``":"``, so userinfo and bracketed IPv6 hosts work.
        """
        parsed_url = urlparse(request.config.option.redis_ssl_url)
        return parsed_url.hostname, parsed_url.port

    def _make_client(self, request, **kwargs):
        """Build a ``redis.Redis`` client aimed at the SSL test server.

        All keyword arguments are forwarded unchanged to ``redis.Redis``.
        """
        host, port = self._ssl_host_port(request)
        return redis.Redis(host=host, port=port, **kwargs)

    def test_ssl_with_invalid_cert(self, request):
        """Connecting with only the URL's settings fails certificate verification."""
        ssl_url = request.config.option.redis_ssl_url
        with redis.from_url(ssl_url) as sslclient:
            with pytest.raises(ConnectionError) as e:
                sslclient.ping()
        assert "SSL: CERTIFICATE_VERIFY_FAILED" in str(e.value)

    def test_ssl_connection(self, request):
        """An SSL connection succeeds when certificate checks are disabled."""
        with self._make_client(request, ssl=True, ssl_cert_reqs="none") as r:
            assert r.ping()

    def test_ssl_connection_without_ssl(self, request):
        """A plaintext client talking to the TLS port gets disconnected."""
        with self._make_client(request, ssl=False) as r:
            with pytest.raises(ConnectionError) as e:
                r.ping()
        assert "Connection closed by server" in str(e.value)

    def test_validating_self_signed_certificate(self, request):
        """Client cert + verification against the CA file (by path) succeeds."""
        with self._make_client(
            request,
            ssl=True,
            ssl_certfile=self.client_certs.certfile,
            ssl_keyfile=self.client_certs.keyfile,
            ssl_cert_reqs="required",
            ssl_ca_certs=self.client_certs.ca_certfile,
        ) as r:
            assert r.ping()

    def test_validating_self_signed_string_certificate(self, request):
        """Same as the file-based test, but the CA is passed via ``ssl_ca_data``."""
        with open(self.client_certs.ca_certfile) as f:
            cert_data = f.read()
        with self._make_client(
            request,
            ssl=True,
            ssl_certfile=self.client_certs.certfile,
            ssl_keyfile=self.client_certs.keyfile,
            ssl_cert_reqs="required",
            ssl_ca_data=cert_data,
        ) as r:
            assert r.ping()

    @pytest.mark.parametrize(
        "ssl_ciphers",
        [
            "AES256-SHA:DHE-RSA-AES256-SHA:AES128-SHA:DHE-RSA-AES128-SHA",
            "DHE-RSA-AES256-GCM-SHA384",
            "ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305",
        ],
    )
    def test_ssl_connection_tls12_custom_ciphers(self, request, ssl_ciphers):
        """Valid OpenSSL-style (TLS <= 1.2) cipher strings are accepted."""
        # NOTE(review): despite the test name, the minimum version is TLS 1.3.
        # ssl_ciphers cannot restrict TLS 1.3 suites, so this mainly checks the
        # cipher string is accepted; confirm whether TLSv1_2 was intended.
        with self._make_client(
            request,
            ssl=True,
            ssl_cert_reqs="none",
            ssl_min_version=ssl.TLSVersion.TLSv1_3,
            ssl_ciphers=ssl_ciphers,
        ) as r:
            assert r.ping()

    def test_ssl_connection_tls12_custom_ciphers_invalid(self, request):
        """An unknown cipher string makes the connection attempt fail."""
        with self._make_client(
            request,
            ssl=True,
            ssl_cert_reqs="none",
            ssl_min_version=ssl.TLSVersion.TLSv1_2,
            ssl_ciphers="foo:bar",
        ) as r:
            with pytest.raises(RedisError) as e:
                r.ping()
        assert "No cipher can be selected" in str(e.value)

    @pytest.mark.parametrize(
        "ssl_ciphers",
        [
            "TLS_CHACHA20_POLY1305_SHA256",
            "TLS_AES_256_GCM_SHA384:TLS_CHACHA20_POLY1305_SHA256",
        ],
    )
    def test_ssl_connection_tls13_custom_ciphers(self, request, ssl_ciphers):
        """TLS 1.3 suite names are rejected when passed as ``ssl_ciphers``.

        TLS 1.3 cipher suites cannot be configured through the cipher string,
        so a string made only of TLS 1.3 names selects no cipher at all.
        """
        with self._make_client(
            request,
            ssl=True,
            ssl_cert_reqs="none",
            ssl_min_version=ssl.TLSVersion.TLSv1_2,
            ssl_ciphers=ssl_ciphers,
        ) as r:
            with pytest.raises(RedisError) as e:
                r.ping()
        assert "No cipher can be selected" in str(e.value)

    def _create_oscp_conn(self, request):
        """Return a verifying client with pure (non-stapled) OCSP validation on."""
        return self._make_client(
            request,
            ssl=True,
            ssl_certfile=self.client_certs.certfile,
            ssl_keyfile=self.client_certs.keyfile,
            ssl_cert_reqs="required",
            ssl_ca_certs=self.client_certs.ca_certfile,
            ssl_validate_ocsp=True,
        )

    @skip_if_cryptography()
    def test_ssl_ocsp_called(self, request):
        """Without ``cryptography`` installed, OCSP validation errors out."""
        with self._create_oscp_conn(request) as r:
            with pytest.raises(RedisError) as e:
                r.ping()
        assert "cryptography is not installed" in str(e.value)

    @skip_if_nocryptography()
    def test_ssl_ocsp_called_withcrypto(self, request):
        """OCSP validation of the test certificate fails for lack of AIA data."""
        with self._create_oscp_conn(request) as r:
            with pytest.raises(ConnectionError) as e:
                r.ping()
        assert "No AIA information present in ssl certificate" in str(e.value)

    @skip_if_nocryptography()
    def test_valid_ocsp_cert_http(self):
        """Certificates of well-known public hosts pass OCSP validation."""
        from redis.ocsp import OCSPVerifier

        hostnames = ["github.com", "aws.amazon.com", "ynet.co.il"]
        for hostname in hostnames:
            context = ssl.create_default_context()
            with socket.create_connection((hostname, 443)) as sock:
                with context.wrap_socket(sock, server_hostname=hostname) as wrapped:
                    ocsp = OCSPVerifier(wrapped, hostname, 443)
                    assert ocsp.is_valid()

    @skip_if_nocryptography()
    def test_revoked_ocsp_certificate(self):
        """A revoked certificate is reported as REVOKED."""
        from redis.ocsp import OCSPVerifier

        context = ssl.create_default_context()
        hostname = "revoked.badssl.com"
        with socket.create_connection((hostname, 443)) as sock:
            with context.wrap_socket(sock, server_hostname=hostname) as wrapped:
                ocsp = OCSPVerifier(wrapped, hostname, 443)
                with pytest.raises(ConnectionError) as e:
                    ocsp.is_valid()
                assert "REVOKED" in str(e.value)

    @skip_if_nocryptography()
    def test_unauthorized_ocsp(self):
        """An OCSP responder that refuses the request raises ConnectionError."""
        from redis.ocsp import OCSPVerifier

        context = ssl.create_default_context()
        hostname = "stackoverflow.com"
        with socket.create_connection((hostname, 443)) as sock:
            with context.wrap_socket(sock, server_hostname=hostname) as wrapped:
                ocsp = OCSPVerifier(wrapped, hostname, 443)
                with pytest.raises(ConnectionError):
                    ocsp.is_valid()

    @skip_if_nocryptography()
    def test_ocsp_not_present_in_response(self):
        """Validation fails when the expected OCSP data is missing."""
        from redis.ocsp import OCSPVerifier

        context = ssl.create_default_context()
        hostname = "google.co.il"
        with socket.create_connection((hostname, 443)) as sock:
            with context.wrap_socket(sock, server_hostname=hostname) as wrapped:
                ocsp = OCSPVerifier(wrapped, hostname, 443)
                with pytest.raises(ConnectionError) as e:
                    ocsp.is_valid()
                assert "from the" in str(e.value)

    @skip_if_nocryptography()
    def test_unauthorized_then_direct(self):
        """Hosts whose first OCSP lookup is unauthorized still validate."""
        from redis.ocsp import OCSPVerifier

        # For these hosts the OCSP request for the socket's certificate comes
        # back "unauthorized"; the verifier's second attempt then succeeds.
        hostnames = ["wikipedia.org", "squarespace.com"]
        for hostname in hostnames:
            context = ssl.create_default_context()
            with socket.create_connection((hostname, 443)) as sock:
                with context.wrap_socket(sock, server_hostname=hostname) as wrapped:
                    ocsp = OCSPVerifier(wrapped, hostname, 443)
                    assert ocsp.is_valid()

    @skip_if_nocryptography()
    def test_mock_ocsp_staple(self, request):
        """Stapled-OCSP validation fails when the server staples no response."""
        import OpenSSL

        # 1) Pure OCSP validation with an OCSP context that is merely non-None.
        with self._make_client(
            request,
            ssl=True,
            ssl_certfile=self.client_certs.certfile,
            ssl_keyfile=self.client_certs.keyfile,
            ssl_cert_reqs="required",
            ssl_ca_certs=self.client_certs.ca_certfile,
            ssl_validate_ocsp=True,
            ssl_ocsp_context=object(),  # just needs to not be None
        ) as r:
            with pytest.raises(RedisError):
                r.ping()

        # 2) Stapled validation with an explicit pyOpenSSL context and an
        #    expected certificate.
        ctx = OpenSSL.SSL.Context(OpenSSL.SSL.SSLv23_METHOD)
        ctx.use_certificate_file(self.client_certs.certfile)
        ctx.use_privatekey_file(self.client_certs.keyfile)
        # Read inside a context manager so the file handle is closed promptly.
        with open(self.server_certs.ca_certfile, "rb") as f:
            expected_cert = f.read()

        with self._make_client(
            request,
            ssl=True,
            ssl_certfile=self.client_certs.certfile,
            ssl_keyfile=self.client_certs.keyfile,
            ssl_cert_reqs="required",
            ssl_ca_certs=self.client_certs.ca_certfile,
            ssl_ocsp_context=ctx,
            ssl_ocsp_expected_cert=expected_cert,
            ssl_validate_ocsp_stapled=True,
        ) as r:
            with pytest.raises(ConnectionError) as e:
                r.ping()
        assert "no ocsp response present" in str(e.value)

        # 3) Stapled validation with no OCSP context supplied.
        with self._make_client(
            request,
            ssl=True,
            ssl_certfile=self.client_certs.certfile,
            ssl_keyfile=self.client_certs.keyfile,
            ssl_cert_reqs="required",
            ssl_ca_certs=self.client_certs.ca_certfile,
            ssl_validate_ocsp_stapled=True,
        ) as r:
            with pytest.raises(ConnectionError) as e:
                r.ping()
        assert "no ocsp response present" in str(e.value)

    def test_cert_reqs_none_with_check_hostname(self, request):
        """Test that when ssl_cert_reqs=none is used with ssl_check_hostname=True,
        the connection is created successfully with check_hostname internally set to False"""
        with self._make_client(
            request,
            ssl=True,
            ssl_cert_reqs="none",
            # Check that ssl_check_hostname is ignored, when ssl_cert_reqs=none
            ssl_check_hostname=True,
        ) as r:
            # Connection should be successful
            assert r.ping()
            assert r.connection_pool.connection_class == redis.SSLConnection
            # check_hostname should have been automatically set to False
            conn = r.connection_pool.make_connection()
            assert conn.check_hostname is False

    def test_ssl_verify_flags_applied_to_context(self, request):
        """
        Test that ssl_include_verify_flags and ssl_exclude_verify_flags
        are stored on the connection and applied to the SSL context it builds.
        """
        import unittest.mock

        # Flags to OR into SSLContext.verify_flags.
        ssl_include_verify_flags = [
            ssl.VerifyFlags.VERIFY_CRL_CHECK_LEAF,  # CRL-check the peer certificate
            ssl.VerifyFlags.VERIFY_CRL_CHECK_CHAIN,  # CRL-check the whole chain
        ]
        # Flags to clear from SSLContext.verify_flags.
        ssl_exclude_verify_flags = [
            ssl.VerifyFlags.VERIFY_X509_STRICT,  # strict X.509 compliance checks
        ]

        with self._make_client(
            request,
            ssl=True,
            ssl_cert_reqs="none",
            ssl_include_verify_flags=ssl_include_verify_flags,
            ssl_exclude_verify_flags=ssl_exclude_verify_flags,
        ) as r:
            conn = r.connection_pool.get_connection()
            assert isinstance(conn, redis.SSLConnection)

            # The flag lists must be stored on the connection as given.
            assert conn.ssl_include_verify_flags is not None
            assert len(conn.ssl_include_verify_flags) == 2
            assert conn.ssl_exclude_verify_flags is not None
            assert len(conn.ssl_exclude_verify_flags) == 1
            for flag in ssl_include_verify_flags:
                assert flag in conn.ssl_include_verify_flags, (
                    f"Flag {flag} not found in stored ssl_include_verify_flags"
                )
            for flag in ssl_exclude_verify_flags:
                assert flag in conn.ssl_exclude_verify_flags, (
                    f"Flag {flag} not found in stored ssl_exclude_verify_flags"
                )

            # Intercept SSLContext.wrap_socket to capture the context that
            # _wrap_socket_with_ssl() builds, without a real TLS handshake.
            captured_context = None

            def capture_context_wrap_socket(context_self, sock, **_kwargs):
                nonlocal captured_context
                captured_context = context_self
                # Return the plain socket instead of wrapping it.
                return sock

            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as mock_sock:
                with unittest.mock.patch.object(
                    ssl.SSLContext, "wrap_socket", capture_context_wrap_socket
                ):
                    try:
                        conn._wrap_socket_with_ssl(mock_sock)
                    except Exception:
                        # Best effort: later steps may fail on the unconnected
                        # socket; only the captured context matters here.
                        pass

            assert captured_context is not None, "SSL context was not captured"
            verify_flags = captured_context.verify_flags

            assert not (verify_flags & ssl.VerifyFlags.VERIFY_X509_STRICT), (
                "VERIFY_X509_STRICT should be disabled but is enabled"
            )
            # Compare against the full mask: VERIFY_CRL_CHECK_CHAIN includes the
            # VERIFY_CRL_CHECK_LEAF bit, so a merely non-zero AND would also pass
            # if only the leaf flag had been applied.
            for flag in ssl_include_verify_flags:
                assert (verify_flags & flag) == flag, (
                    f"{flag!r} should be enabled but is disabled"
                )
